[JAX] [PyT] [Common] Enable D=256 BWD cuDNN fused attn for Blackwell CC 10.x #3056
[JAX] [PyT] [Common] Enable D=256 BWD cuDNN fused attn for Blackwell CC 10.x #3056KshitijLakhani wants to merge 15 commits into
Conversation
51ad582 to
d177ecf
Compare
Greptile SummaryThis PR enables D=256 backward (bprop) support for Blackwell SM10x GPUs via the cuDNN deterministic SDPA bprop kernel introduced in cuDNN 9.23 / FE 1.24. A new guard is added to the
Confidence Score: 4/5Safe to merge for non-THD users; the known THD + D=256 cuDNN plan-build failure is documented with a strict xfail test rather than a backend exclusion, leaving a latent hard exception for production THD + D=256 training on SM10x. The C++ backend gate correctly activates the new kernel path for BSHD/SBHD layouts and the test suite validates the main happy-path configurations. The THD layout issue remains unguarded in the backend selector, leaving a latent hard runtime exception for production THD + D=256 training workloads on SM10x. transformer_engine/common/fused_attn/fused_attn.cpp — the new D=256 guard does not exclude THD-format layouts, which cannot build a cuDNN 9.23 execution plan. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["nvte_get_fused_attn_backend()"] --> B{dtype FP16/BF16?}
B -->|yes| C{flag_arb conditions}
B -->|no| Z[Other backends / FP8]
C --> D{arch & version gate}
D -->|SM80/90 paths| D2[Earlier gates]
D -->|cuDNN ge 9.7 and SM ge 100| E{head_dim check}
E --> E1["d le 128 always"]
E --> E2["d le 256 + Hopper cuDNN ge 9.1/9.5"]
E --> E3["any d + Blackwell fprop cuDNN ge 9.9"]
E --> E4["d_qk=192 d_v=128 + Blackwell bprop cuDNN ge 9.11"]
E --> E5{"NEW: d_qk=d_v=256 + SM10x bprop cuDNN ge 9.23"}
E5 --> G{no_bias and no_dropout and vanilla_softmax and non-paged and window_size OK?}
G -->|no| F[Fall through]
G -->|yes| H[flag_arb = true]
H --> I{outer mask/format/SWA checks}
I -->|fail| F
I -->|pass| J[Return NVTE_F16_arbitrary_seqlen]
Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
| # vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only. | ||
| # (for non-causal masks) full-window attention. |
There was a problem hiding this comment.
The comment block ends with a repeated phrase: line 383 (# (for non-causal masks) full-window attention.) is a verbatim fragment of line 382, left over from editing. It should be removed.
| # vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only. | |
| # (for non-causal masks) full-window attention. | |
| # vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only. |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
There was a problem hiding this comment.
Seems like some editing glitch :)
| # Non-learnable bias is fine (bias is allowed as an input); only dBias is | ||
| # unsupported. The JAX runner asks for dBias iff the bias shape is [1, h, s, s] | ||
| # (see test_backward), so gate on that. | ||
| unsupported = None | ||
| if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: | ||
| unsupported = "pre-scale bias" | ||
| elif self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS: | ||
| unsupported = ( | ||
| "bias gradients (dBias); frozen/non-learnable bias inputs" | ||
| " (i.e. non-1HSS bias shapes) are supported" | ||
| ) |
There was a problem hiding this comment.
JAX skip logic diverges from C++ backend gate for non-1HSS bias
The comment says "frozen/non-learnable bias inputs (i.e. non-1HSS bias shapes) are supported" and the skip block deliberately allows those configs to proceed. However, the C++ gate in fused_attn.cpp requires bias_type == NVTE_NO_BIAS for the new D=256 BWD path, meaning any config with attn_bias_type != NO_BIAS && bias_shape != _1HSS will silently fall back to a different backend rather than exercising the new kernel. The test will not fail, but it also will not validate the D=256 BWD path for those configs, and the inline comment creates a misleading expectation that such configs are actually routed through it.
| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && | ||
| (window_size_right == -1 || window_size_right == 0)))) || |
There was a problem hiding this comment.
Could these changes be moved to before "\ bias type" just so it's following an increasing order of the cuDNN version?
There was a problem hiding this comment.
Does this new feature support BSHD/SBHD and THD? It looks like the tests are focused on BSHD/SBHD only.
There was a problem hiding this comment.
RE: THD support
I did test BSHD and BSHD+CP and it did pass on the JAX side and the CI for the PyT side did not fail either so I think that works.
My testing revealed that THD support is not yet available (Bwd plan compialtion issue) so I've filed a bug and shared a reproducer for the same with the cuDNN team: NVIDIA/cudnn-frontend#276
There was a problem hiding this comment.
Could these changes be moved to before "\ bias type" just so it's following an increasing order of the cuDNN version?
Fixed in a264de1
| "D=256 BWD on Blackwell only supports right window -1 or 0" | ||
| " for causal masks." | ||
| ) | ||
|
|
There was a problem hiding this comment.
Aren't these checks duplicate to the checks we added on the C++ side? Would the call FusedAttnHelper().get_fused_attn_backend() give you the same gating effect?
There was a problem hiding this comment.
So if we are just interested in the gating effect, you are right. The get_fused_attn_backend() will return NVTE_No_Backend and then there's a catch-all at the end which basically skip the tests as there is no fused attn backend avalable.
However, the reason for this to be here is to give a meaningful reason as to why a test is being skipped as compared to a generic "Unsupported inputs combination or device compute capability." message which does not qualify the reason for the skip. Unfortunately, on the JAX attn side we do not log the reason for disabling fused attn in the feature code like we have on the Pytorch side in d_p_a/utils.py. So there is no way for the user to know why the test was skipped. Hence, we need to rely on test code to log this on the JAX side.
I'd suggest we leave this in here for now. And when your PR for generating log messages in the C++ level when selecting the attn backend is ready, I can plumb it through onto the JAX side and then as part of that clean up, get rid of all the skip messages in check_configs()
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
…n fused attn Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
e317f99 to
e08e9e8
Compare
| // 9.23: d_qk = d_v = 256 + SM10x (cuDNN FE 1.24 / BE 9.23+) + bprop + non-paged | ||
| (head_dim_qk == 256 && head_dim_v == 256 && is_training && sm_arch_ >= 100 && | ||
| sm_arch_ < 110 && cudnn_runtime_version >= 92300 && | ||
| layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && | ||
| // The FE forces this path onto the deterministic bprop algorithm, which on | ||
| // Blackwell rejects dBias, dropout, and ALiBi (and supports vanilla softmax only). | ||
| bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0 && | ||
| softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX && | ||
| // Non-causal D=256 supports only full-window attention; SWA is allowed only for causal masks. | ||
| ((window_size_left == -1 && window_size_right == -1) || | ||
| ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || | ||
| attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && | ||
| (window_size_right == -1 || window_size_right == 0)))) || |
There was a problem hiding this comment.
D=256 BWD condition includes THD layout, causing a hard runtime exception
The new condition only excludes NVTE_Paged_KV_HD_HD_HD but does not exclude THD-format layouts. NVTE_THD_T2HD maps to layout_group = NVTE_HD_2HD and qkv_format = NVTE_THD, both of which pass all guards here and in the outer flag_arb checks (the qkv_format check at line 417 allows THD when sm_arch_ >= 90, which is true for SM10x). So nvte_get_fused_attn_backend returns NVTE_F16_arbitrary_seqlen for full-window THD + D=256 + SM10x + cuDNN ≥ 9.23, claiming support — but cuDNN 9.23 fails to build an execution plan for this layout, and NVTE_CHECK_CUDNN_FE on lines 421–422 of fused_attn_f16_arbitrary_seqlen.cu will throw a hard exception. The JAX xfail test documents the failure, but any production user with THD + D=256 training will hit an unrecoverable runtime error rather than a graceful backend fallback. Adding qkv_format != NVTE_QKV_Format::NVTE_THD to this condition would fix the backend selector; the JAX xfail test would then SKIP instead of XFAIL (which could be separately handled if you want to preserve the sentinel behaviour).
There was a problem hiding this comment.
This support is forwarding looking, i.e., support is added for THD and BSHD in the TE common fused attn backend checking code, however, the PR is still waiting on cuDNN to fix support for THD.
The current PR will not be merged as is. One of two things will happen:
- cuDNN will fix THD support and only then will this PR be merged (most likely) - after fixing the XFAIL for THD cases to skips for a specific cuDNNv version
- cuDNN will not fix this soon in which case I will switch the support to BSHD only prior to merging this PR
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…ersions Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
Description
Support for D=256 BWD for Blackwell CC 10x via the C++ API (which TE uses) was added in cuDNN 9.23 + cuDNN FE 1.24. Enabling this support in TE attention
Type of change
Changes
Add guard when picking the backend (sub backend) in TE common.
Add tests for D=256 case in TE PyT and TE JAX
Checklist: